import torch
from torch import nn
import  numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Global matplotlib settings for this script:
#  - 'font.sans-serif' = SimHei so Chinese characters in labels render.
#  - 'axes.unicode_minus' = False so the minus sign displays correctly
#    (ASCII hyphen-minus instead of the Unicode minus glyph).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# Font-properties mapping: Times New Roman, normal weight, size 15.
# NOTE(review): `word` is not referenced anywhere in the visible file —
# confirm whether it is still needed.
word = dict(family='Times New Roman', weight='normal', size=15)


# Load the four saved arrays, in this fixed order, before any figure exists.
a, b, c, d = (np.load(npy_path) for npy_path in (
    'remain_grad.npy',
    'pruned_grad.npy',
    'remain_grad_filter.npy',
    'prune_grad_filter.npy',
))

# One figure with a logarithmic y-axis; each array is drawn as a line
# against its index. (Assumes 1-D arrays — TODO confirm; a 2-D array would
# be drawn as one line per column.)
plt.figure()
plt.yscale('log')
for _series, _label in zip(
    (a, b, c, d),
    ('remain_grad', 'prune_grad', 'remain_grad_filter', 'prune_grad_filter'),
):
    plt.plot(_series, label=_label)
del _series, _label  # don't leave loop temporaries in the module namespace
plt.legend()
plt.show()


